from zlai.utils import pkg_config
if pkg_config.python_version >= (3, 11):
    from typing import Self, List, Dict, Union, Literal, Optional
else:
    from typing_extensions import Self
    from typing import List, Dict, Union, Literal, Optional
from pydantic import Field
from zlai.types.function_call import *
from zlai.types.messages.display import *
from .base import Message
from .messages import UserMessage, AssistantMessage, SystemMessage
from .function import ObservationMessage, ToolsMessage
from .image import ImageMessage
from .audio import AudioMessage
from .cite import CiteMessage


# Public API of this module: only ChatCompletionMessage is exported via `import *`.
__all__ = [
    "ChatCompletionMessage",
]


# Union of the zlai message classes a ChatCompletionMessage can be converted
# into (via to_message / to_user_message / to_assistant_message / ...); used as
# the return annotation of ChatCompletionMessage.to_message.
# NOTE(review): member order is kept as-is — if this alias is used as a pydantic
# field type elsewhere, union validation may depend on it.
TypeCompletionTransMessage = Union[
    Message,
    UserMessage,
    AssistantMessage,
    SystemMessage,
    ImageMessage,
    AudioMessage,
    ObservationMessage,
    ToolsMessage,
    CiteMessage,
]


class ChatCompletionMessage(Message):
    """Chat-completion message with role, content and optional function/tool-call data.

    Instances can be converted into the more specific zlai message classes
    (see ``to_message``) and rendered in Streamlit (see ``show_streamlit``).
    """
    role: Literal["user", "assistant", "system", "tool", "observation"] = Field(
        ..., description="""The role of the author of this message.""")
    content: Optional[Union[str, List[Dict]]] = Field(default=None, description="""The contents of the message.""")
    function_call: Optional[FunctionCall] = Field(default=None, description="""
        Deprecated and replaced by `tool_calls`.
        The name and arguments of a function that should be called, as generated by the
        model.
    """)
    name: Optional[str] = Field(default=None, description="""The name of the function to call.""")
    tool_call_id: Optional[str] = Field(default=None, description="""Tool call that this message is responding to.""")
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = Field(default=None, description="""
        The tool calls generated by the model, such as function calls.
    """)

    def _has_content_type(self, content_type: str) -> bool:
        """Return True if list-form ``content`` contains a dict part whose ``"type"`` is ``content_type``.

        String or ``None`` content never matches; non-dict list items are skipped.
        """
        if not isinstance(self.content, list):
            return False
        return any(
            isinstance(part, dict) and part.get("type") == content_type
            for part in self.content
        )

    def _is_audio_message(self) -> bool:
        """Return True if the content contains a part of type ``"audio"``."""
        return self._has_content_type("audio")

    def _is_image_message(self) -> bool:
        """Return True if the content contains a part of type ``"image_url"``."""
        return self._has_content_type("image_url")

    def _is_cite_message(self) -> bool:
        """Return True if the content contains a part of type ``"cite"``."""
        return self._has_content_type("cite")

    def to_base_message(self) -> Message:
        """Return a plain ``Message`` validated from this message's ``model_dump()``."""
        return Message.model_validate(self.model_dump())

    def to_user_message(self) -> UserMessage:
        """Return a ``UserMessage`` carrying only this message's content."""
        return UserMessage(content=self.content)

    def to_assistant_message(self) -> AssistantMessage:
        """Return an ``AssistantMessage`` carrying only this message's content."""
        return AssistantMessage(content=self.content)

    def to_system_message(self) -> SystemMessage:
        """Return a ``SystemMessage`` carrying only this message's content."""
        return SystemMessage(content=self.content)

    def to_audio_message(self) -> AudioMessage:
        """Return an ``AudioMessage`` validated from this message's ``model_dump()``."""
        return AudioMessage.model_validate(self.model_dump())

    def to_image_message(self) -> ImageMessage:
        """Return an ``ImageMessage`` validated from this message's ``model_dump()``."""
        return ImageMessage.model_validate(self.model_dump())

    def to_message(self) -> Union[Self, TypeCompletionTransMessage]:
        """Convert this message into the most specific zlai message class.

        The checks run in this order and the first match wins:

        1. ``tool_calls`` or ``function_call`` is set -> ``self`` unchanged.
        2. ``tool_call_id`` and ``name`` are both set -> ``ToolsMessage``.
        3. content has an ``"audio"`` part -> ``AudioMessage``.
        4. content has an ``"image_url"`` part -> ``ImageMessage``.
        5. ``role == "observation"`` -> ``ObservationMessage``.
        6. content has a ``"cite"`` part -> ``CiteMessage``.
        7. otherwise -> a plain ``Message``.
        """
        if self.tool_calls is not None or self.function_call is not None:
            return self
        if self.tool_call_id is not None and self.name is not None:
            return ToolsMessage.model_validate(self.model_dump())
        if self._is_audio_message():
            return self.to_audio_message()
        if self._is_image_message():
            return self.to_image_message()
        if self.role == "observation":
            return ObservationMessage.model_validate(self.model_dump())
        if self._is_cite_message():
            return CiteMessage.model_validate(self.model_dump())
        return self.to_base_message()

    def to_dict(self) -> Dict:
        """Return this message as a dict (``model_dump()``)."""
        return self.model_dump()

    def _show_tool_message(self):
        """Render every entry of ``tool_calls`` in Streamlit via ``show_tool_call``.

        A ``None`` or empty ``tool_calls`` renders nothing.
        """
        # Called for its side effect; defined on Message (presumably checks that
        # streamlit is usable — NOTE(review): confirm). Return value is unused.
        self._validate_streamlit()
        # Render all parallel tool calls, not just the first one.
        for tool_call in self.tool_calls or []:
            show_tool_call(tool_call=tool_call)

    def show_streamlit(self):
        """Render this message in Streamlit.

        - Non-empty ``tool_calls``: each tool call is drawn via ``show_tool_call``.
        - Legacy ``function_call``-only message, or an empty ``tool_calls`` list:
          drawn as the plain base ``Message``.
        - Anything else: converted with ``to_message`` and drawn by the resulting
          message's own ``show_streamlit``.
        """
        if self.tool_calls:
            self._show_tool_message()
        elif self.tool_calls is not None or self.function_call is not None:
            # There are no tool calls to draw (previously this indexed
            # tool_calls[0] and crashed), and to_message() returns ``self`` for
            # such messages, so delegating to it would recurse back here.
            # NOTE(review): the function_call payload itself is not rendered.
            self.to_base_message().show_streamlit()
        else:
            self.to_message().show_streamlit()
